{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "keras.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "# Distributed Training with Keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "**Learning Objectives**\n",
        "\n",
        "1. How to define a distribution strategy and set up the input pipeline.\n",
        "2. How to create the Keras model.\n",
        "3. How to define the callbacks.\n",
        "4. How to train and evaluate the model.\n",
        "\n",
        "## Introduction\n",
        "\n",
        "The `tf.distribute.Strategy` API provides an abstraction for distributing your training\n",
        "across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.\n",
        "\n",
        "This notebook uses the `tf.distribute.MirroredStrategy`, which\n",
        "does in-graph replication with synchronous training on many GPUs on one machine.\n",
        "Essentially, it copies all of the model's variables to each processor.\n",
        "Then, it uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors and applies the combined value to all copies of the model.\n",
        "\n",
        "`MirroredStrategy` is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the [distribution strategy guide](https://raw.githubusercontent.com/tensorflow/docs/master/site/en/guide/distributed_training.ipynb).\n",
        "\n",
        "Each learning objective corresponds to a __#TODO__ in the [student lab notebook](../labs/keras.ipynb). Try to complete that notebook before reviewing this solution notebook.\n"
      ]
    },
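    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The synchronous update described above can be sketched in plain Python. This is a toy illustration, not how `MirroredStrategy` is implemented: the helper `all_reduce_mean` and the two-replica gradient values are hypothetical, and a real all-reduce runs on the devices themselves. Whether gradients are summed or averaged depends on how the loss is scaled; the sketch assumes a mean for simplicity.\n",
        "\n",
        "```python\n",
        "def all_reduce_mean(per_replica_grads):\n",
        "    '''Average gradients element-wise across replicas; return one copy per replica.'''\n",
        "    num_replicas = len(per_replica_grads)\n",
        "    # Reduce: combine each gradient component over all replicas.\n",
        "    reduced = [sum(component) / num_replicas for component in zip(*per_replica_grads)]\n",
        "    # Broadcast: every replica receives its own copy of the combined gradient.\n",
        "    return [list(reduced) for _ in range(num_replicas)]\n",
        "\n",
        "\n",
        "# Two hypothetical replicas, each holding gradients for the same three variables.\n",
        "grads = [[0.25, -0.5, 1.0], [0.75, 0.0, 3.0]]\n",
        "synced = all_reduce_mean(grads)\n",
        "\n",
        "# Each replica applies the identical update, so the variable copies stay in sync.\n",
        "assert synced[0] == synced[1]\n",
        "print(synced[0])\n",
        "```\n",
        "\n",
        "Because every replica receives the same combined gradient, all copies of the variables remain identical after each step."
      ]
    },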
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "### Keras API\n",
        "\n",
        "This example uses the `tf.keras` API to build the model and training loop. For custom training loops, see the [tf.distribute.Strategy with training loops](training_loops.ipynb) tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dney9v7BsJij"
      },
      "source": [
        "## Import dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r8S3ublR7Ay8"
      },
      "source": [
        "# Import TensorFlow and TensorFlow Datasets\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow as tf\n",
        "\n",
        "import os"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SkocY8tgRd3H",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "93f81b8a-4cd1-4a44-acaf-b2c3fd50d404"
      },
      "source": [
        "# Here we'll show the currently installed version of TensorFlow\n",
        "print(tf.__version__)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2.6.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hXhefksNKk2I"
      },
      "source": [
        "## Download the dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OtnnUwvmB3X5"
      },
      "source": [
        "Download the MNIST dataset and load it from [TensorFlow Datasets](https://www.tensorflow.org/datasets). This returns a dataset in `tf.data` format."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lHAPqG8MtS8M"
      },
      "source": [
        "Setting `with_info` to `True` includes the metadata for the entire dataset, which is being saved here to `info`.\n",
        "Among other things, this metadata object includes the number of train and test examples. \n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iXMJ3G9NB3X6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 222,
          "referenced_widgets": [
            "87a63a803ef64822bd91f083558278ce",
            "29cd8bd6a1654ddb890f8e26afbbb574",
            "954d6bb981e64053a56b71336c7c76b6",
            "b1a06090e773431e9a7f0c2c89ecda19",
            "5dfab89c78d14363bd424513d9bc8a7e",
            "84909d8b4a9c4c8586319b7f2482e3a7",
            "4dea4d007cac494da0526c363d96e94b",
            "eac25bd6c24144a9be4cdc3c46d3031f"
          ]
        },
        "outputId": "ad3818a6-0c79-4b89-82ff-ec63f9ff0234"
      },
      "source": [
        "# Loads the named dataset into a tf.data.Dataset\n",
        "# TODO\n",
        "datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)\n",
        "\n",
        "mnist_train, mnist_test = datasets['train'], datasets['test']"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\u001b[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/jupyter/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "87a63a803ef64822bd91f083558278ce",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\n",
            "\n",
            "\u001b[1mDataset mnist downloaded and prepared to /home/jupyter/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001b[0m\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "2021-09-14 12:42:12.088968: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GrjVhv-eKuHD"
      },
      "source": [
        "## Define distribution strategy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TlH8vx6BB3X9"
      },
      "source": [
        "Create a `MirroredStrategy` object. This object handles distribution and provides a context manager (`tf.distribute.MirroredStrategy.scope`) inside which you build your model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4j0tdf4YB3X9",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e20ea159-5088-4776-be9e-a4a71787a7bf"
      },
      "source": [
        "# TODO\n",
        "# Synchronous training across multiple replicas on one machine.\n",
        "strategy = tf.distribute.MirroredStrategy()"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cY3KA_h2iVfN",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "eaf4aec4-694f-4005-a08e-7b62a84e349a"
      },
      "source": [
        "print('Number of devices: {}'.format(strategy.num_replicas_in_sync))"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Number of devices: 1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lNbPv0yAleW8"
      },
      "source": [
        "## Set up the input pipeline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "psozqcuptXhK"
      },
      "source": [
        "When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p1xWxKcnhar9"
      },
      "source": [
        "# You can also do info.splits.total_num_examples to get the total\n",
        "# number of examples in the dataset.\n",
        "\n",
        "num_train_examples = info.splits['train'].num_examples\n",
        "num_test_examples = info.splits['test'].num_examples\n",
        "\n",
        "BUFFER_SIZE = 10000\n",
        "\n",
        "BATCH_SIZE_PER_REPLICA = 64\n",
        "BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync"
      ],
      "execution_count": 7,
      "outputs": []
    },
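    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough check of the arithmetic above, the sketch below computes the global batch size and the number of steps per epoch for a few replica counts. The replica counts are hypothetical, `input_plan` is an illustrative helper, and 60,000 is the standard MNIST training split size. The linear learning-rate scaling at the end is a common heuristic rather than something this notebook applies; treat it as an assumption to tune.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "NUM_TRAIN_EXAMPLES = 60000   # standard MNIST training split size\n",
        "BATCH_SIZE_PER_REPLICA = 64\n",
        "BASE_LEARNING_RATE = 1e-3    # assumed single-replica learning rate\n",
        "\n",
        "\n",
        "def input_plan(num_replicas):\n",
        "    '''Return (global batch size, steps per epoch, scaled learning rate).'''\n",
        "    global_batch = BATCH_SIZE_PER_REPLICA * num_replicas\n",
        "    # Dataset.batch keeps the final partial batch by default, hence the ceiling.\n",
        "    steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / global_batch)\n",
        "    # Linear scaling heuristic (an assumption, not used in this notebook).\n",
        "    scaled_lr = BASE_LEARNING_RATE * num_replicas\n",
        "    return global_batch, steps_per_epoch, scaled_lr\n",
        "\n",
        "\n",
        "for replicas in (1, 2, 4):  # hypothetical replica counts\n",
        "    print(replicas, input_plan(replicas))\n",
        "```\n",
        "\n",
        "With one replica this gives 938 steps per epoch, which matches the progress bar in the training output of this notebook."
      ]
    },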
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Wm5rsL2KoDF"
      },
      "source": [
        "Pixel values, which range from 0 to 255, [have to be normalized to the 0-1 range](https://en.wikipedia.org/wiki/Feature_scaling). Define this scaling in a function."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Eo9a46ZeJCkm"
      },
      "source": [
        "def scale(image, label):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image /= 255\n",
        "\n",
        "  return image, label"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WZCa5RLc5A91"
      },
      "source": [
        "Apply this function to the training and test data, shuffle the training data, and [batch it for training](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch). Notice we are also keeping an in-memory cache of the training data to improve performance.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gRZu2maChwdT"
      },
      "source": [
        "train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4xsComp8Kz5H"
      },
      "source": [
        "## Create the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1BnQYQTpB3YA"
      },
      "source": [
        "Create and compile the Keras model in the context of `strategy.scope`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IexhL_vIB3YA",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7ea5ab8a-c328-4a14-bd79-8ba35e598f7d"
      },
      "source": [
        "with strategy.scope():\n",
        "  model = tf.keras.Sequential([\n",
        "      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(64, activation='relu'),\n",
        "      tf.keras.layers.Dense(10)\n",
        "  ])\n",
        "\n",
        "  # Configure the model for training.\n",
        "  # TODO\n",
        "  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=['accuracy'])"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8i6OU5W9Vy2u"
      },
      "source": [
        "## Define the callbacks\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YOXO5nvvK3US"
      },
      "source": [
        "The callbacks used here are:\n",
        "\n",
        "*   *TensorBoard*: This callback writes a log for TensorBoard which allows you to visualize the graphs.\n",
        "*   *Model Checkpoint*: This callback saves the model's weights after every epoch.\n",
        "*   *Learning Rate Scheduler*: Using this callback, you can schedule the learning rate to change from one epoch to the next.\n",
        "\n",
        "For illustrative purposes, add a print callback to display the *learning rate* in the notebook."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A9bwLCcXzSgy"
      },
      "source": [
        "# Define the checkpoint directory to store the checkpoints\n",
        "\n",
        "# TODO\n",
        "checkpoint_dir = './training_checkpoints'\n",
        "# Name of the checkpoint files\n",
        "# TODO\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")"
      ],
      "execution_count": 11,
      "outputs": []
    },
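    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `{epoch}` placeholder in `checkpoint_prefix` is filled in by `ModelCheckpoint` with the 1-based epoch number each time it saves. The sketch below reproduces that substitution with plain `str.format`, only to show which file prefixes to expect; it does not call Keras, and the list name `prefixes` is illustrative.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')\n",
        "\n",
        "# Keras numbers epochs from 1 when formatting the path, so 12 epochs\n",
        "# produce the prefixes ckpt_1 through ckpt_12.\n",
        "prefixes = [checkpoint_prefix.format(epoch=epoch) for epoch in range(1, 13)]\n",
        "\n",
        "print(prefixes[0])\n",
        "print(prefixes[-1])\n",
        "```\n",
        "\n",
        "Each prefix ends up with an `.index` file and one or more `.data-*` shards once weights are written, as the directory listing later in this notebook shows."
      ]
    },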
    {
      "cell_type": "code",
      "metadata": {
        "id": "wpU-BEdzJDbK"
      },
      "source": [
        "# Function for decaying the learning rate.\n",
        "# You can define any decay function you need.\n",
        "def decay(epoch):\n",
        "  if epoch < 3:\n",
        "    return 1e-3\n",
        "  elif epoch < 7:\n",
        "    return 1e-4\n",
        "  else:\n",
        "    return 1e-5"
      ],
      "execution_count": 12,
      "outputs": []
    },
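    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what this schedule produces, the sketch below evaluates a `decay` function with the same thresholds over 12 epochs, the number used for training in this notebook. `LearningRateScheduler` passes a 0-based epoch index, so the first drop happens at index 3, which the Keras progress output reports as epoch 4. The `schedule` list is only for illustration.\n",
        "\n",
        "```python\n",
        "def decay(epoch):\n",
        "    '''Piecewise-constant learning rate; epoch is a 0-based index.'''\n",
        "    if epoch < 3:\n",
        "        return 1e-3\n",
        "    if epoch < 7:\n",
        "        return 1e-4\n",
        "    return 1e-5\n",
        "\n",
        "\n",
        "schedule = [decay(epoch) for epoch in range(12)]\n",
        "\n",
        "# Print 1-based epoch numbers to match the Keras progress output.\n",
        "for index, learning_rate in enumerate(schedule):\n",
        "    print('epoch {:2d}: {:g}'.format(index + 1, learning_rate))\n",
        "```\n",
        "\n",
        "Tabulating the schedule like this before training is a cheap way to catch off-by-one mistakes in the epoch thresholds."
      ]
    },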
    {
      "cell_type": "code",
      "metadata": {
        "id": "jKhiMgXtKq2w"
      },
      "source": [
        "# Callback for printing the LR at the end of each epoch.\n",
        "class PrintLR(tf.keras.callbacks.Callback):\n",
        "  def on_epoch_end(self, epoch, logs=None):\n",
        "    print('\\nLearning rate for epoch {} is {}'.format(epoch + 1,\n",
        "                                                      self.model.optimizer.lr.numpy()))"
      ],
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YVqAbR6YyNQh"
      },
      "source": [
        "callbacks = [\n",
        "    tf.keras.callbacks.TensorBoard(log_dir='./logs'),\n",
        "    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,\n",
        "                                       save_weights_only=True),\n",
        "    tf.keras.callbacks.LearningRateScheduler(decay),\n",
        "    PrintLR()\n",
        "]"
      ],
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "70HXgDQmK46q"
      },
      "source": [
        "## Train and evaluate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6EophnOAB3YD"
      },
      "source": [
        "Now, train the model in the usual way, calling `fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7MVw_6CqB3YD",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1d72cbdc-11aa-4beb-eeb4-dcfcba5fce72"
      },
      "source": [
        "# Train the model with the callbacks defined above\n",
        "# TODO\n",
        "model.fit(train_dataset, epochs=12, callbacks=callbacks)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1/12\n",
            "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "  3/938 [..............................] - ETA: 41s - loss: 2.2821 - accuracy: 0.1172    WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0062s vs `on_train_batch_end` time: 0.0128s). Check your callbacks.\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0062s vs `on_train_batch_end` time: 0.0128s). Check your callbacks.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "938/938 [==============================] - 16s 7ms/step - loss: 0.4077 - accuracy: 0.8858\n",
            "\n",
            "Learning rate for epoch 1 is 0.0010000000474974513\n",
            "Epoch 2/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0711 - accuracy: 0.9787\n",
            "\n",
            "Learning rate for epoch 2 is 0.0010000000474974513\n",
            "Epoch 3/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0459 - accuracy: 0.9861\n",
            "\n",
            "Learning rate for epoch 3 is 0.0010000000474974513\n",
            "Epoch 4/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0279 - accuracy: 0.9921\n",
            "\n",
            "Learning rate for epoch 4 is 9.999999747378752e-05\n",
            "Epoch 5/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0237 - accuracy: 0.9935\n",
            "\n",
            "Learning rate for epoch 5 is 9.999999747378752e-05\n",
            "Epoch 6/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0211 - accuracy: 0.9943\n",
            "\n",
            "Learning rate for epoch 6 is 9.999999747378752e-05\n",
            "Epoch 7/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0204 - accuracy: 0.9948\n",
            "\n",
            "Learning rate for epoch 7 is 9.999999747378752e-05\n",
            "Epoch 8/12\n",
            "938/938 [==============================] - 4s 4ms/step - loss: 0.0179 - accuracy: 0.9953\n",
            "\n",
            "Learning rate for epoch 8 is 9.999999747378752e-06\n",
            "Epoch 9/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0181 - accuracy: 0.9954\n",
            "\n",
            "Learning rate for epoch 9 is 9.999999747378752e-06\n",
            "Epoch 10/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0168 - accuracy: 0.9959\n",
            "\n",
            "Learning rate for epoch 10 is 9.999999747378752e-06\n",
            "Epoch 11/12\n",
            "938/938 [==============================] - 4s 4ms/step - loss: 0.0166 - accuracy: 0.9958\n",
            "\n",
            "Learning rate for epoch 11 is 9.999999747378752e-06\n",
            "Epoch 12/12\n",
            "938/938 [==============================] - 3s 4ms/step - loss: 0.0162 - accuracy: 0.9959\n",
            "\n",
            "Learning rate for epoch 12 is 9.999999747378752e-06\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<tensorflow.python.keras.callbacks.History at 0x7f88eeb0cd68>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NUcWAUUupIvG"
      },
      "source": [
        "The listing below shows that a checkpoint was saved after each of the 12 epochs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JQ4zeSTxKEhB",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6468c115-bac6-4952-bb16-a85f01297388"
      },
      "source": [
        "# check the checkpoint directory\n",
        "!ls {checkpoint_dir}"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "checkpoint\t\t     ckpt_4.data-00000-of-00001\n",
            "ckpt_10.data-00000-of-00001  ckpt_4.index\n",
            "ckpt_10.index\t\t     ckpt_5.data-00000-of-00001\n",
            "ckpt_11.data-00000-of-00001  ckpt_5.index\n",
            "ckpt_11.index\t\t     ckpt_6.data-00000-of-00001\n",
            "ckpt_12.data-00000-of-00001  ckpt_6.index\n",
            "ckpt_12.index\t\t     ckpt_7.data-00000-of-00001\n",
            "ckpt_1.data-00000-of-00001   ckpt_7.index\n",
            "ckpt_1.index\t\t     ckpt_8.data-00000-of-00001\n",
            "ckpt_2.data-00000-of-00001   ckpt_8.index\n",
            "ckpt_2.index\t\t     ckpt_9.data-00000-of-00001\n",
            "ckpt_3.data-00000-of-00001   ckpt_9.index\n",
            "ckpt_3.index\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qor53h7FpMke"
      },
      "source": [
        "To see how the model performs, load the latest checkpoint and call `evaluate` on the test data.\n",
        "\n",
        "Pass the evaluation dataset (`eval_dataset`) to `evaluate`, just as the training dataset was passed to `fit`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JtEwxiTgpQoP",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4c0cd338-ce12-4715-8806-2fcdbfe8ecd8"
      },
      "source": [
        "# Loads the weights\n",
        "model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n",
        "\n",
        "eval_loss, eval_acc = model.evaluate(eval_dataset)\n",
        "\n",
        "print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "157/157 [==============================] - 3s 8ms/step - loss: 0.0381 - accuracy: 0.9874\n",
            "Eval loss: 0.03809080645442009, Eval Accuracy: 0.9873999953269958\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IIeF2RWfYu4N"
      },
      "source": [
        "To inspect the training run, download the TensorBoard logs and launch TensorBoard from a terminal.\n",
        "\n",
        "```\n",
        "$ tensorboard --logdir=path/to/log-directory\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LnyscOkvKKBR",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "624c3432-b986-459c-aad1-024b13d74e10"
      },
      "source": [
        "!ls -sh ./logs"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "total 4.0K\n",
            "4.0K train\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kBLlogrDvMgg"
      },
      "source": [
        "## Export to SavedModel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xa87y_A0vRma"
      },
      "source": [
        "Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without `strategy.scope`.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h8Q4MKOLwG7K"
      },
      "source": [
        "path = 'saved_model/'"
      ],
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4HvcDmVsvQoa",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "35b708b8-8829-42cd-c796-375fd17f0752"
      },
      "source": [
        "# Save the entire model as a SavedModel.\n",
        "# TODO\n",
        "model.save(path, save_format='tf')"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: saved_model/assets\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: saved_model/assets\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vKJT4w5JwVPI"
      },
      "source": [
        "Load the model without `strategy.scope`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T_gT0RbRvQ3o",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3f77b369-3132-4407-f217-57e612f22362"
      },
      "source": [
        "unreplicated_model = tf.keras.models.load_model(path)\n",
        "\n",
        "unreplicated_model.compile(\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer=tf.keras.optimizers.Adam(),\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)\n",
        "\n",
        "print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "157/157 [==============================] - 1s 3ms/step - loss: 0.0361 - accuracy: 0.9892\n",
            "Eval loss: 0.03809080645442009, Eval Accuracy: 0.9873999953269958\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YBLzcRF0wbDe"
      },
      "source": [
        "Load the model with `strategy.scope`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BBVo3WGGwd9a",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "398d611e-1407-4f76-bd7c-aa33fc907023"
      },
      "source": [
        "# Recreate the exact same model, including its weights and the optimizer\n",
        "with strategy.scope():\n",
        "  replicated_model = tf.keras.models.load_model(path)\n",
        "  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                           optimizer=tf.keras.optimizers.Adam(),\n",
        "                           metrics=['accuracy'])\n",
        "\n",
        "  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)\n",
        "  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "157/157 [==============================] - 3s 4ms/step - loss: 0.0361 - accuracy: 0.9892\n",
            "Eval loss: 0.03809080645442009, Eval Accuracy: 0.9873999953269958\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUZwaz4AKjtD"
      },
      "source": [
        "### Examples and Tutorials\n",
        "Here are some examples of using distribution strategies with Keras `fit`/`compile`:\n",
        "1. [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer_main.py) example trained using `tf.distribute.MirroredStrategy`.\n",
        "2. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained using `tf.distribute.MirroredStrategy`.\n",
        "\n",
        "More examples are listed in the [Distribution strategy guide](https://raw.githubusercontent.com/tensorflow/docs/master/site/en/guide/distributed_training.ipynb#examples_and_tutorials)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8uNqWRdDMl5S"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "* Read the [distribution strategy guide](https://raw.githubusercontent.com/tensorflow/docs/master/site/en/guide/distributed_training.ipynb).\n",
        "* Read the [Distributed Training with Custom Training Loops](https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/distribute/custom_training.ipynb) tutorial.\n",
        "* Visit the [Performance section](https://raw.githubusercontent.com/tensorflow/docs/master/site/en/guide/function.ipynb) in the guide to learn more about other strategies and [tools](https://raw.githubusercontent.com/tensorflow/docs/master/site/en/guide/profiler.md) you can use to optimize the performance of your TensorFlow models.\n",
        "\n",
        "Note: `tf.distribute.Strategy` is actively under development and we will be adding more examples and tutorials in the near future. Please give it a try. We welcome your feedback via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new)."
      ]
    }
  ]
}
